{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "customInput": null,
        "customOutput": null,
        "originalKey": "f0af2d90-cb21-4ab4-b4cb-0fd00dbfb77b",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "# Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved."
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "4e15bfa2-5404-40d0-98b6-eb2732c8b72b",
        "showInput": false
      },
      "source": [
        "# Implicitron's config system"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "287be985-423d-42e0-a2af-1e8c585e723c",
        "showInput": false
      },
      "source": [
        "Implicitron's components are all based on a unified hierarchical configuration system.\n",
        "This allows the configurable variables of each new component, and all their defaults, to be defined separately for that component.\n",
        "All configs relevant to an experiment are then automatically composed into a single configuration file that fully specifies the experiment.\n",
        "An especially important feature is extension points, where users can insert their own subclasses of Implicitron's base components.\n",
        "\n",
        "The system is defined in [`config.py`](https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/implicitron/tools/config.py) in the PyTorch3D repo.\n",
        "The Implicitron volumes tutorial contains a simple example of using the config system.\n",
        "This tutorial provides detailed hands-on experience in using and modifying Implicitron's configurable components.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "fde300a2-99cb-4d52-9d5b-4464a2083e0b",
        "showInput": false
      },
      "source": [
        "## 0. Install and import modules\n",
        "\n",
        "Ensure `torch` and `torchvision` are installed. If `pytorch3d` is not installed, install it using the following cell:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "customInput": null,
        "customOutput": null,
        "originalKey": "ad6e94a7-e114-43d3-b038-a5210c7d34c9",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import sys\n",
        "import torch\n",
        "import subprocess\n",
        "need_pytorch3d=False\n",
        "try:\n",
        "    import pytorch3d\n",
        "except ModuleNotFoundError:\n",
        "    need_pytorch3d=True\n",
        "if need_pytorch3d:\n",
        "    pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n",
        "    version_str=\"\".join([\n",
        "        f\"py3{sys.version_info.minor}_cu\",\n",
        "        torch.version.cuda.replace(\".\",\"\"),\n",
        "        f\"_pyt{pyt_version_str}\"\n",
        "    ])\n",
        "    !pip install iopath\n",
        "    if sys.platform.startswith(\"linux\"):\n",
        "        print(\"Trying to install wheel for PyTorch3D\")\n",
        "        !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n",
        "        pip_list = !pip freeze\n",
        "        need_pytorch3d = not any(i.startswith(\"pytorch3d==\") for i in pip_list)\n",
        "    if need_pytorch3d:\n",
        "        print(f\"failed to find/install wheel for {version_str}\")\n",
        "if need_pytorch3d:\n",
        "    print(\"Installing PyTorch3D from source\")\n",
        "    !pip install ninja\n",
        "    !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "609896c0-9e2e-4716-b074-b565f0170e32",
        "showInput": false
      },
      "source": [
        "Ensure omegaconf is installed. If not, run this cell. (It should not be necessary to restart the runtime.)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "customInput": null,
        "customOutput": null,
        "originalKey": "d1c1851e-b9f2-4236-93c3-19aa4d63041c",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "!pip install omegaconf"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "code_folding": [],
        "collapsed": false,
        "customOutput": null,
        "executionStartTime": 1659465468717,
        "executionStopTime": 1659465468738,
        "hidden_ranges": [],
        "originalKey": "5ac7ef23-b74c-46b2-b8d3-799524d7ba4f",
        "requestMsgId": "5ac7ef23-b74c-46b2-b8d3-799524d7ba4f"
      },
      "outputs": [],
      "source": [
        "from dataclasses import dataclass\n",
        "from typing import Optional, Tuple\n",
        "\n",
        "import torch\n",
        "from omegaconf import DictConfig, OmegaConf\n",
        "from pytorch3d.implicitron.tools.config import (\n",
        "    Configurable,\n",
        "    ReplaceableBase,\n",
        "    expand_args_fields,\n",
        "    get_default_args,\n",
        "    registry,\n",
        "    run_auto_creation,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "a638bf90-eb6b-424d-b53d-eae11954a717",
        "showInput": false
      },
      "source": [
        "## 1. Introducing dataclasses\n",
        "\n",
        "[Type hints](https://docs.python.org/3/library/typing.html) let you annotate Python variables and class members with their expected types. [Dataclasses](https://docs.python.org/3/library/dataclasses.html) let you create a class from a list of members, each with a name, a type and possibly a default value. The `__init__` function is generated automatically and, as its final step, calls a `__post_init__` function if one is defined. For example:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659454972732,
        "executionStopTime": 1659454972739,
        "originalKey": "71eaad5e-e198-492e-8610-24b0da9dd4ae",
        "requestMsgId": "71eaad5e-e198-492e-8610-24b0da9dd4ae",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "@dataclass\n",
        "class MyDataclass:\n",
        "    a: int\n",
        "    b: int = 8\n",
        "    c: Optional[Tuple[int, ...]] = None\n",
        "\n",
        "    def __post_init__(self):\n",
        "        print(f\"created with a = {self.a}\")\n",
        "        self.d = 2 * self.b"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659454973051,
        "executionStopTime": 1659454973077,
        "originalKey": "83202a18-a3d3-44ec-a62d-b3360a302645",
        "requestMsgId": "83202a18-a3d3-44ec-a62d-b3360a302645",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "my_dataclass_instance = MyDataclass(a=18)\n",
        "assert my_dataclass_instance.d == 16"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "b67ccb9f-dc6c-4994-9b99-b5a1bcfebd70",
        "showInput": false
      },
      "source": [
        "👷 Note that the `dataclass` decorator here is a function which modifies the definition of the class itself.\n",
        "It runs immediately after the class is defined.\n",
        "In our config system, classes in the Implicitron library must, once modified, be aware of user-defined implementations.\n",
        "Therefore the modification of the class has to be delayed, so we don't use a decorator.\n"
      ]
    },
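    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Because a decorator is just a function applied to the class, the same modification can be deferred by calling `dataclasses.dataclass` on the class later, as an ordinary function. The sketch below is plain standard-library Python, not Implicitron code, and `LateDataclass` is a hypothetical name used only for illustration.\n",
        "\n",
        "```python\n",
        "from dataclasses import dataclass, fields\n",
        "\n",
        "\n",
        "# A plain class with annotated members; no decorator is applied yet.\n",
        "class LateDataclass:\n",
        "    a: int\n",
        "    b: int = 8\n",
        "\n",
        "\n",
        "# So far the class has no generated __init__ accepting a and b.\n",
        "assert \"__init__\" not in vars(LateDataclass)\n",
        "\n",
        "# Later, apply the decorator as a function call. dataclass() modifies\n",
        "# the class in place and returns it.\n",
        "LateDataclass = dataclass(LateDataclass)\n",
        "\n",
        "instance = LateDataclass(a=1)\n",
        "print(instance.a, instance.b)  # 1 8\n",
        "print([f.name for f in fields(LateDataclass)])  # ['a', 'b']\n",
        "```\n",
        "\n",
        "Delaying this call until all the relevant classes are known is, roughly, the idea behind `expand_args_fields`, introduced in section **3**."
      ]
    },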
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "3e90f664-99df-4387-9c45-a1ad7939ef3a",
        "showInput": false
      },
      "source": [
        "## 2. Introducing omegaconf and OmegaConf.structured\n",
        "\n",
        "The [omegaconf](https://github.com/omry/omegaconf/) library provides a `DictConfig` class, which is like a `dict` with string keys but with extra features that make it convenient as a configuration system."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659451341683,
        "executionStopTime": 1659451341690,
        "originalKey": "81c73c9b-27ee-4aab-b55e-fb0dd67fe174",
        "requestMsgId": "81c73c9b-27ee-4aab-b55e-fb0dd67fe174",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "dc = DictConfig({\"a\": 2, \"b\": True, \"c\": None, \"d\": \"hello\"})\n",
        "assert dc.a == dc[\"a\"] == 2"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "3b5b76a9-4b76-4784-96ff-2a1212e48e48",
        "showInput": false
      },
      "source": [
        "OmegaConf can serialize configs to and from YAML. The [Hydra](https://hydra.cc/) library relies on this for its configuration files."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659451411835,
        "executionStopTime": 1659451411936,
        "originalKey": "d7a25ec1-caea-46bc-a1da-4b1f040c4b61",
        "requestMsgId": "d7a25ec1-caea-46bc-a1da-4b1f040c4b61",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "print(OmegaConf.to_yaml(dc))\n",
        "assert OmegaConf.create(OmegaConf.to_yaml(dc)) == dc"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "777fecdd-8bf6-4fd8-827b-cb8af5477fa8",
        "showInput": false
      },
      "source": [
        "`OmegaConf.structured` creates a `DictConfig` from a dataclass or an instance of a dataclass. Unlike a plain `DictConfig`, the result is type-checked, and only keys declared in the dataclass can be added."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659455098879,
        "executionStopTime": 1659455098900,
        "originalKey": "de36efb4-0b08-4fb8-bb3a-be1b2c0cd162",
        "requestMsgId": "de36efb4-0b08-4fb8-bb3a-be1b2c0cd162",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "structured = OmegaConf.structured(MyDataclass)\n",
        "assert isinstance(structured, DictConfig)\n",
        "print(structured)\n",
        "print()\n",
        "print(OmegaConf.to_yaml(structured))"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "be4446da-e536-4139-9ba3-37669a5b5e61",
        "showInput": false
      },
      "source": [
        "`structured` knows that it is missing a value for `a`: the YAML shows it as `???`, OmegaConf's marker for a missing mandatory value."
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "864811e8-1a75-4932-a85e-f681b0541ae9",
        "showInput": false
      },
      "source": [
        "Such an object has the same keys as the dataclass fields, so it can be unpacked into the constructor to create an instance:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659455580491,
        "executionStopTime": 1659455580501,
        "originalKey": "eb88aaa0-c22f-4ffb-813a-ca957b490acb",
        "requestMsgId": "eb88aaa0-c22f-4ffb-813a-ca957b490acb",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "structured.a = 21\n",
        "my_dataclass_instance2 = MyDataclass(**structured)\n",
        "print(my_dataclass_instance2)"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "2d08c81c-9d18-4de9-8464-0da2d89f94f3",
        "showInput": false
      },
      "source": [
        "You can also call `OmegaConf.structured` on an instance; the instance's current field values then become the config's values."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659455594700,
        "executionStopTime": 1659455594737,
        "originalKey": "5e469bac-32a4-475d-9c09-8b64ba3f2155",
        "requestMsgId": "5e469bac-32a4-475d-9c09-8b64ba3f2155",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "structured_from_instance = OmegaConf.structured(my_dataclass_instance)\n",
        "my_dataclass_instance3 = MyDataclass(**structured_from_instance)\n",
        "print(my_dataclass_instance3)"
      ]
    },
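    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For comparison, a similar round trip can be done with the standard library alone, using `dataclasses.asdict` to produce a plain `dict` (without any of OmegaConf's type checking). One detail worth noticing: `d`, which is set in `__post_init__`, is not a field, so it is not part of the mapping and is recomputed when a new instance is constructed. `Example` is a hypothetical class mirroring `MyDataclass`.\n",
        "\n",
        "```python\n",
        "from dataclasses import asdict, dataclass\n",
        "from typing import Optional, Tuple\n",
        "\n",
        "\n",
        "@dataclass\n",
        "class Example:\n",
        "    a: int\n",
        "    b: int = 8\n",
        "    c: Optional[Tuple[int, ...]] = None\n",
        "\n",
        "    def __post_init__(self):\n",
        "        # d is derived from b; it is not a dataclass field.\n",
        "        self.d = 2 * self.b\n",
        "\n",
        "\n",
        "original = Example(a=18, b=5)\n",
        "as_dict = asdict(original)\n",
        "print(as_dict)  # {'a': 18, 'b': 5, 'c': None}\n",
        "\n",
        "# Unpacking the mapping runs __init__ and __post_init__ again,\n",
        "# so d is recomputed rather than copied.\n",
        "rebuilt = Example(**as_dict)\n",
        "assert rebuilt == original\n",
        "assert rebuilt.d == 10\n",
        "```"
      ]
    },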
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659452594203,
        "executionStopTime": 1659452594333,
        "originalKey": "2ed559e3-8552-465a-938f-30c72a321184",
        "requestMsgId": "2ed559e3-8552-465a-938f-30c72a321184",
        "showInput": false
      },
      "source": [
        "## 3. Our approach to OmegaConf.structured\n",
        "\n",
        "We provide functions which are equivalent to `OmegaConf.structured` but support more features.\n",
        "The following cells reproduce the example above using our functions.\n",
        "Note that we mark configurable classes with a special base class, `Configurable`, rather than a decorator."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659454053323,
        "executionStopTime": 1659454061629,
        "originalKey": "9888afbd-e617-4596-ab7a-fc1073f58656",
        "requestMsgId": "9888afbd-e617-4596-ab7a-fc1073f58656",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "class MyConfigurable(Configurable):\n",
        "    a: int\n",
        "    b: int = 8\n",
        "    c: Optional[Tuple[int, ...]] = None\n",
        "\n",
        "    def __post_init__(self):\n",
        "        print(f\"created with a = {self.a}\")\n",
        "        self.d = 2 * self.b"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659454784912,
        "executionStopTime": 1659454784928,
        "originalKey": "e43155b4-3da5-4df1-a2f5-da1d0369eec9",
        "requestMsgId": "e43155b4-3da5-4df1-a2f5-da1d0369eec9",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "# expand_args_fields modifies the class, like @dataclasses.dataclass does.\n",
        "# If it has not been called on a Configurable class before the class is\n",
        "# first instantiated, it is called automatically.\n",
        "expand_args_fields(MyConfigurable)\n",
        "my_configurable_instance = MyConfigurable(a=18)\n",
        "assert my_configurable_instance.d == 16"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659460669541,
        "executionStopTime": 1659460669566,
        "originalKey": "96eaae18-dce4-4ee1-b451-1466fea51b9f",
        "requestMsgId": "96eaae18-dce4-4ee1-b451-1466fea51b9f",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "# get_default_args also calls expand_args_fields automatically\n",
        "our_structured = get_default_args(MyConfigurable)\n",
        "assert isinstance(our_structured, DictConfig)\n",
        "print(OmegaConf.to_yaml(our_structured))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659460454020,
        "executionStopTime": 1659460454032,
        "originalKey": "359f7925-68de-42cd-bd34-79a099b1c210",
        "requestMsgId": "359f7925-68de-42cd-bd34-79a099b1c210",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "our_structured.a = 21\n",
        "print(MyConfigurable(**our_structured))"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659460599142,
        "executionStopTime": 1659460599149,
        "originalKey": "eac7d385-9365-4098-acf9-4f0a0dbdcb85",
        "requestMsgId": "eac7d385-9365-4098-acf9-4f0a0dbdcb85",
        "showInput": false
      },
      "source": [
        "## 4. First enhancement: nested types 🪺\n",
        "\n",
        "Our system allows Configurable classes to contain each other.\n",
        "One thing to remember: the containing class must call `run_auto_creation(self)` in its `__post_init__`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659465752418,
        "executionStopTime": 1659465752976,
        "originalKey": "9bd70ee5-4ec1-4021-bce5-9638b5088c0a",
        "requestMsgId": "9bd70ee5-4ec1-4021-bce5-9638b5088c0a",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "class Inner(Configurable):\n",
        "    a: int = 8\n",
        "    b: bool = True\n",
        "    c: Tuple[int, ...] = (2, 3, 4, 6)\n",
        "\n",
        "\n",
        "class Outer(Configurable):\n",
        "    inner: Inner\n",
        "    x: str = \"hello\"\n",
        "    xx: bool = False\n",
        "\n",
        "    def __post_init__(self):\n",
        "        run_auto_creation(self)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659465762326,
        "executionStopTime": 1659465762339,
        "originalKey": "9f2b9f98-b54b-46cc-9b02-9e902cb279e7",
        "requestMsgId": "9f2b9f98-b54b-46cc-9b02-9e902cb279e7",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "outer_dc = get_default_args(Outer)\n",
        "print(OmegaConf.to_yaml(outer_dc))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659465772894,
        "executionStopTime": 1659465772911,
        "originalKey": "0254204b-8c7a-4d40-bba6-5132185f63d7",
        "requestMsgId": "0254204b-8c7a-4d40-bba6-5132185f63d7",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "outer = Outer(**outer_dc)\n",
        "assert isinstance(outer, Outer)\n",
        "assert isinstance(outer.inner, Inner)\n",
        "print(vars(outer))\n",
        "print(outer.inner)"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "44a78c13-ec92-4a87-808a-c4674b320c22",
        "showInput": false
      },
      "source": [
        "Note that `inner_args` is an extra member of `outer`, holding the config for `inner`. `run_auto_creation(self)` is equivalent to\n",
        "```python\n",
        "self.inner = Inner(**self.inner_args)\n",
        "```"
      ]
    },
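    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make this equivalence concrete, here is a standard-library toy which performs the same step by hand. `ToyInner` and `ToyOuter` are hypothetical stand-ins for `Inner` and `Outer`; in Implicitron the `inner_args` member is added for you and `run_auto_creation` performs the construction.\n",
        "\n",
        "```python\n",
        "from dataclasses import dataclass, field\n",
        "from typing import Any, Dict, Tuple\n",
        "\n",
        "\n",
        "@dataclass\n",
        "class ToyInner:\n",
        "    a: int = 8\n",
        "    b: bool = True\n",
        "    c: Tuple[int, ...] = (2, 3, 4, 6)\n",
        "\n",
        "\n",
        "@dataclass\n",
        "class ToyOuter:\n",
        "    # The nested member's config is stored as a plain mapping...\n",
        "    inner_args: Dict[str, Any] = field(default_factory=dict)\n",
        "    x: str = \"hello\"\n",
        "    xx: bool = False\n",
        "\n",
        "    def __post_init__(self):\n",
        "        # ...and the member itself is built from it, mirroring\n",
        "        # self.inner = Inner(**self.inner_args).\n",
        "        self.inner = ToyInner(**self.inner_args)\n",
        "\n",
        "\n",
        "outer = ToyOuter(inner_args={\"a\": 3})\n",
        "print(outer.inner.a, outer.inner.c)  # 3 (2, 3, 4, 6)\n",
        "```"
      ]
    },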
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659461071129,
        "executionStopTime": 1659461071137,
        "originalKey": "af0ec78b-7888-4b0d-9346-63d970d43293",
        "requestMsgId": "af0ec78b-7888-4b0d-9346-63d970d43293",
        "showInput": false
      },
      "source": [
        "## 5. Second enhancement: pluggable/replaceable components 🔌\n",
        "\n",
        "If a class uses `ReplaceableBase` as a base class instead of `Configurable`, we call it a *replaceable*.\n",
        "This indicates that it is designed to have one of its subclasses used in its place.\n",
        "Methods which subclasses are expected to implement can raise `NotImplementedError`.\n",
        "The system maintains a global `registry` containing the subclasses of each `ReplaceableBase`.\n",
        "Subclasses register themselves with it using the `@registry.register` decorator.\n",
        "\n",
        "A configurable class (i.e. a class which uses our system: a subclass of `Configurable` or `ReplaceableBase`) which contains a `ReplaceableBase` member\n",
        "must also contain a corresponding `str` field which names the concrete subclass to use. In the example below, the member `inner` is paired with the field `inner_class_type`."
      ]
    },
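    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough mental model, a registry can be as simple as a dictionary from class names to classes, filled in by a decorator. The toy below is a hypothetical sketch (`ToyRegistry`, `Greeter` and `EnglishGreeter` are illustrative names), not the implementation of Implicitron's `registry`.\n",
        "\n",
        "```python\n",
        "from typing import Dict, Type\n",
        "\n",
        "\n",
        "# A toy registry mapping class names to classes.\n",
        "class ToyRegistry:\n",
        "    def __init__(self):\n",
        "        self._classes: Dict[str, Type] = {}\n",
        "\n",
        "    def register(self, cls: Type) -> Type:\n",
        "        # Used as a decorator: record the class, then return it unchanged.\n",
        "        self._classes[cls.__name__] = cls\n",
        "        return cls\n",
        "\n",
        "    def get(self, base: Type, name: str) -> Type:\n",
        "        # Look up a registered class by name and check that it fits the base.\n",
        "        cls = self._classes[name]\n",
        "        if not issubclass(cls, base):\n",
        "            raise ValueError(f\"{name} is not a subclass of {base.__name__}\")\n",
        "        return cls\n",
        "\n",
        "\n",
        "toy_registry = ToyRegistry()\n",
        "\n",
        "\n",
        "class Greeter:\n",
        "    def say_something(self):\n",
        "        raise NotImplementedError\n",
        "\n",
        "\n",
        "@toy_registry.register\n",
        "class EnglishGreeter(Greeter):\n",
        "    def say_something(self):\n",
        "        return \"hello\"\n",
        "\n",
        "\n",
        "# A string, like the value of inner_class_type, selects the concrete class.\n",
        "chosen = toy_registry.get(Greeter, \"EnglishGreeter\")\n",
        "print(chosen().say_something())  # hello\n",
        "```"
      ]
    },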
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659463453457,
        "executionStopTime": 1659463453467,
        "originalKey": "f2898703-d147-4394-978e-fc7f1f559395",
        "requestMsgId": "f2898703-d147-4394-978e-fc7f1f559395",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "class InnerBase(ReplaceableBase):\n",
        "    def say_something(self):\n",
        "        raise NotImplementedError\n",
        "\n",
        "\n",
        "@registry.register\n",
        "class Inner1(InnerBase):\n",
        "    a: int = 1\n",
        "    b: str = \"h\"\n",
        "\n",
        "    def say_something(self):\n",
        "        print(\"hello from an Inner1\")\n",
        "\n",
        "\n",
        "@registry.register\n",
        "class Inner2(InnerBase):\n",
        "    a: int = 2\n",
        "\n",
        "    def say_something(self):\n",
        "        print(\"hello from an Inner2\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659463453514,
        "executionStopTime": 1659463453592,
        "originalKey": "6f171599-51ee-440f-82d7-a59f84d24624",
        "requestMsgId": "6f171599-51ee-440f-82d7-a59f84d24624",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "class Out(Configurable):\n",
        "    inner: InnerBase\n",
        "    inner_class_type: str = \"Inner1\"\n",
        "    x: int = 19\n",
        "\n",
        "    def __post_init__(self):\n",
        "        run_auto_creation(self)\n",
        "\n",
        "    def talk(self):\n",
        "        self.inner.say_something()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659463191360,
        "executionStopTime": 1659463191428,
        "originalKey": "7abaecec-96e6-44df-8c8d-69c36a14b913",
        "requestMsgId": "7abaecec-96e6-44df-8c8d-69c36a14b913",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "Out_dc = get_default_args(Out)\n",
        "print(OmegaConf.to_yaml(Out_dc))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659463192717,
        "executionStopTime": 1659463192754,
        "originalKey": "c82dc2ca-ba8f-4a44-aed3-43f6b52ec28c",
        "requestMsgId": "c82dc2ca-ba8f-4a44-aed3-43f6b52ec28c",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "Out_dc.inner_class_type = \"Inner2\"\n",
        "out = Out(**Out_dc)\n",
        "print(out.inner)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659463193751,
        "executionStopTime": 1659463193791,
        "originalKey": "aa0e1b04-963a-4724-81b7-5748b598b541",
        "requestMsgId": "aa0e1b04-963a-4724-81b7-5748b598b541",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "out.talk()"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "4f78a56c-39cd-4563-a97e-041e5f360f6b",
        "showInput": false
      },
      "source": [
        "Note that in this case there are many `args` members. Your code can usually ignore them; they are needed for the config."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659462145294,
        "executionStopTime": 1659462145307,
        "originalKey": "ce7069d5-a813-4286-a7cd-6ff40362105a",
        "requestMsgId": "ce7069d5-a813-4286-a7cd-6ff40362105a",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "print(vars(out))"
      ]
    },
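    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "One way to picture why these members exist: the config can keep a separate set of arguments for every registered implementation, while the `class_type` string picks which set is used. The standard-library toy below models that idea; `ToyOut`, `impl_args`, `ImplA` and `ImplB` are hypothetical, and Implicitron's real member names may differ (the output of `vars(out)` above shows them).\n",
        "\n",
        "```python\n",
        "from dataclasses import dataclass, field\n",
        "from typing import Any, Dict\n",
        "\n",
        "\n",
        "@dataclass\n",
        "class ImplA:\n",
        "    a: int = 1\n",
        "    b: str = \"h\"\n",
        "\n",
        "\n",
        "@dataclass\n",
        "class ImplB:\n",
        "    a: int = 2\n",
        "\n",
        "\n",
        "IMPLEMENTATIONS = {\"ImplA\": ImplA, \"ImplB\": ImplB}\n",
        "\n",
        "\n",
        "@dataclass\n",
        "class ToyOut:\n",
        "    inner_class_type: str = \"ImplA\"\n",
        "    # One argument mapping per implementation; all of them stay in the\n",
        "    # config even though only the selected one is used.\n",
        "    impl_args: Dict[str, Dict[str, Any]] = field(\n",
        "        default_factory=lambda: {\"ImplA\": {}, \"ImplB\": {}}\n",
        "    )\n",
        "\n",
        "    def __post_init__(self):\n",
        "        cls = IMPLEMENTATIONS[self.inner_class_type]\n",
        "        self.inner = cls(**self.impl_args[self.inner_class_type])\n",
        "\n",
        "\n",
        "out = ToyOut(inner_class_type=\"ImplB\")\n",
        "print(type(out.inner).__name__, out.inner.a)  # ImplB 2\n",
        "```"
      ]
    },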
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659462231114,
        "executionStopTime": 1659462231130,
        "originalKey": "c7f051ff-c264-4b89-80dc-36cf179aafaf",
        "requestMsgId": "c7f051ff-c264-4b89-80dc-36cf179aafaf",
        "showInput": false
      },
      "source": [
        "## 6. Example with torch.nn.Module 🔥\n",
        "In Implicitron, we typically use this system in combination with [`Module`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html)s.\n",
        "Note that in this case it is necessary to call `Module.__init__` explicitly in `__post_init__`, because the generated `__init__` does not call it."
      ]
    },
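    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Why is the explicit call needed? A dataclass-style generated `__init__` only assigns the declared fields; it does not call the base class's `__init__`. The standard-library sketch below shows the effect, with a hypothetical `Base` standing in for `torch.nn.Module` (whose `__init__` sets up internal state such as its parameter and submodule dictionaries). We assume here that the `__init__` generated for a `Configurable` behaves like the standard dataclass one in this respect.\n",
        "\n",
        "```python\n",
        "from dataclasses import dataclass\n",
        "\n",
        "\n",
        "class Base:\n",
        "    def __init__(self):\n",
        "        # State that every instance needs, created by the base initializer.\n",
        "        self.children = []\n",
        "\n",
        "\n",
        "@dataclass\n",
        "class Forgetful(Base):\n",
        "    size: int = 2\n",
        "\n",
        "\n",
        "@dataclass\n",
        "class Careful(Base):\n",
        "    size: int = 2\n",
        "\n",
        "    def __post_init__(self):\n",
        "        # The generated __init__ only assigns fields, so call the base explicitly.\n",
        "        super().__init__()\n",
        "\n",
        "\n",
        "print(hasattr(Forgetful(), \"children\"))  # False\n",
        "print(hasattr(Careful(), \"children\"))  # True\n",
        "```"
      ]
    },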
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659462645018,
        "executionStopTime": 1659462645037,
        "originalKey": "42d210d6-09e0-4daf-8ccb-411d30f268f4",
        "requestMsgId": "42d210d6-09e0-4daf-8ccb-411d30f268f4",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "class MyLinear(torch.nn.Module, Configurable):\n",
        "    d_in: int = 2\n",
        "    d_out: int = 200\n",
        "\n",
        "    def __post_init__(self):\n",
        "        super().__init__()  # runs torch.nn.Module.__init__\n",
        "        self.linear = torch.nn.Linear(in_features=self.d_in, out_features=self.d_out)\n",
        "\n",
        "    def forward(self, x):\n",
        "        return self.linear(x)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659462692309,
        "executionStopTime": 1659462692346,
        "originalKey": "546781fe-5b95-4e48-9cb5-34a634a31313",
        "requestMsgId": "546781fe-5b95-4e48-9cb5-34a634a31313",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "my_linear = MyLinear()\n",
        "x = torch.zeros(2)\n",
        "output = my_linear(x)\n",
        "print(\"output shape:\", output.shape)"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659462738302,
        "executionStopTime": 1659462738419,
        "originalKey": "b6cb71e1-1d54-4e89-a422-0a70772c5c03",
        "requestMsgId": "b6cb71e1-1d54-4e89-a422-0a70772c5c03",
        "showInput": false
      },
      "source": [
        "`my_linear` has all the usual features of a `Module`.\n",
        "For example, it can be saved and loaded with `torch.save` and `torch.load`, and it has parameters:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659462821485,
        "executionStopTime": 1659462821501,
        "originalKey": "47e8c53e-2d2c-4b41-8aa3-65aa3ea8a7d3",
        "requestMsgId": "47e8c53e-2d2c-4b41-8aa3-65aa3ea8a7d3",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "for name, value in my_linear.named_parameters():\n",
        "    print(name, value.shape)"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659463222379,
        "executionStopTime": 1659463222409,
        "originalKey": "a01f0ea7-55f2-4af9-8e81-45dddf40f13b",
        "requestMsgId": "a01f0ea7-55f2-4af9-8e81-45dddf40f13b",
        "showInput": false
      },
      "source": [
        "## 7. Example of implementing your own pluggable component\n",
        "Suppose you are using a library which provides `Out`, as in section **5**, but you want to implement your own subclass of `InnerBase`.\n",
        "All you need to do is register its definition, but you must do this before `expand_args_fields` is called, explicitly or implicitly, on `Out`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659463694644,
        "executionStopTime": 1659463694653,
        "originalKey": "d9635511-a52b-43d5-8dae-d5c1a3dd9157",
        "requestMsgId": "d9635511-a52b-43d5-8dae-d5c1a3dd9157",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "@registry.register\n",
        "class UserImplementedInner(InnerBase):\n",
        "    a: int = 200\n",
        "\n",
        "    def say_something(self):\n",
        "        print(\"hello from the user\")"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "f1511aa2-56b8-4ed0-a453-17e2bbfeefe7",
        "showInput": false
      },
      "source": [
        "At this point, we need to redefine the class `Out`.\n",
        "`Out` was already expanded in section **5**, before `UserImplementedInner` was registered, and the implementations known to a class are fixed when it is expanded.\n",
        "Without redefining it, the following would not work.\n",
        "\n",
        "If you are running experiments from a script, the thing to remember is that you must import your own modules, which register your own implementations,\n",
        "before you *use* the library classes."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659463745967,
        "executionStopTime": 1659463745986,
        "originalKey": "c7bb5a6e-682b-4eb0-a214-e0f5990b9406",
        "requestMsgId": "c7bb5a6e-682b-4eb0-a214-e0f5990b9406",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "class Out(Configurable):\n",
        "    inner: InnerBase\n",
        "    inner_class_type: str = \"Inner1\"\n",
        "    x: int = 19\n",
        "\n",
        "    def __post_init__(self):\n",
        "        run_auto_creation(self)\n",
        "\n",
        "    def talk(self):\n",
        "        self.inner.say_something()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659463747398,
        "executionStopTime": 1659463747431,
        "originalKey": "b6ecdc86-4b7b-47c6-9f45-a7e557c94979",
        "requestMsgId": "b6ecdc86-4b7b-47c6-9f45-a7e557c94979",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "out2 = Out(inner_class_type=\"UserImplementedInner\")\n",
        "print(out2.inner)"
      ]
    },
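    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The user implementation can also be selected through the config rather than a constructor argument. The following sketch, assuming `get_default_args` from `pytorch3d.implicitron.tools.config`, prints the default config of the redefined `Out` (which should now contain an `inner_UserImplementedInner_args` entry) and then builds an `Out` from the edited config."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from omegaconf import OmegaConf\n",
        "from pytorch3d.implicitron.tools.config import get_default_args\n",
        "\n",
        "# UserImplementedInner was registered before this Out was expanded,\n",
        "# so its args should appear in the default config next to the other implementations.\n",
        "cfg = get_default_args(Out)\n",
        "print(OmegaConf.to_yaml(cfg))\n",
        "\n",
        "# Select the user implementation purely through the config.\n",
        "cfg.inner_class_type = \"UserImplementedInner\"\n",
        "out3 = Out(**cfg)\n",
        "assert isinstance(out3.inner, UserImplementedInner)\n",
        "out3.talk()  # prints: hello from the user"
      ]
    },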
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659464033633,
        "executionStopTime": 1659464033643,
        "originalKey": "c7fe0df3-da13-40b8-9b06-6b1f37f37bb9",
        "requestMsgId": "c7fe0df3-da13-40b8-9b06-6b1f37f37bb9",
        "showInput": false
      },
      "source": [
        "## 8. Example of making a subcomponent pluggable\n",
        "\n",
        "Let's look at what needs to change when we make an existing subcomponent pluggable, so that users can supply their own implementations."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659464709922,
        "executionStopTime": 1659464709933,
        "originalKey": "e37227b2-6897-4033-8560-9f2040abdeeb",
        "requestMsgId": "e37227b2-6897-4033-8560-9f2040abdeeb",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "class SubComponent(Configurable):\n",
        "    x: float = 0.25\n",
        "\n",
        "    def apply(self, a: float) -> float:\n",
        "        return a + self.x\n",
        "\n",
        "\n",
        "class LargeComponent(Configurable):\n",
        "    repeats: int = 4\n",
        "    subcomponent: SubComponent\n",
        "\n",
        "    def __post_init__(self):\n",
        "        run_auto_creation(self)\n",
        "\n",
        "    def apply(self, a: float) -> float:\n",
        "        for _ in range(self.repeats):\n",
        "            a = self.subcomponent.apply(a)\n",
        "        return a"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659464710339,
        "executionStopTime": 1659464710459,
        "originalKey": "cab4c121-350e-443f-9a49-bd542a9735a2",
        "requestMsgId": "cab4c121-350e-443f-9a49-bd542a9735a2",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "large_component = LargeComponent()\n",
        "assert large_component.apply(3) == 4\n",
        "print(OmegaConf.to_yaml(LargeComponent))"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "be60323a-badf-46e4-a259-72cae1391028",
        "showInput": false
      },
      "source": [
        "Here is the same code, with the subcomponent made pluggable:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659464717226,
        "executionStopTime": 1659464717261,
        "originalKey": "fc0d8cdb-4627-4427-b92a-17ac1c1b37b8",
        "requestMsgId": "fc0d8cdb-4627-4427-b92a-17ac1c1b37b8",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "class SubComponentBase(ReplaceableBase):\n",
        "    def apply(self, a: float) -> float:\n",
        "        raise NotImplementedError\n",
        "\n",
        "\n",
        "@registry.register\n",
        "class SubComponent(SubComponentBase):\n",
        "    x: float = 0.25\n",
        "\n",
        "    def apply(self, a: float) -> float:\n",
        "        return a + self.x\n",
        "\n",
        "\n",
        "class LargeComponent(Configurable):\n",
        "    repeats: int = 4\n",
        "    subcomponent: SubComponentBase\n",
        "    subcomponent_class_type: str = \"SubComponent\"\n",
        "\n",
        "    def __post_init__(self):\n",
        "        run_auto_creation(self)\n",
        "\n",
        "    def apply(self, a: float) -> float:\n",
        "        for _ in range(self.repeats):\n",
        "            a = self.subcomponent.apply(a)\n",
        "        return a"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659464725473,
        "executionStopTime": 1659464725587,
        "originalKey": "bbc3d321-6b49-4356-be75-1a173b1fc3a5",
        "requestMsgId": "bbc3d321-6b49-4356-be75-1a173b1fc3a5",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "large_component = LargeComponent()\n",
        "assert large_component.apply(3) == 4\n",
        "print(OmegaConf.to_yaml(LargeComponent))"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659464672680,
        "executionStopTime": 1659464673231,
        "originalKey": "5115453a-1d96-4022-97e7-46433e6dcf60",
        "requestMsgId": "5115453a-1d96-4022-97e7-46433e6dcf60",
        "showInput": false
      },
      "source": [
        "The following things had to change:\n",
        "* The base class `SubComponentBase` was defined.\n",
        "* `SubComponent` gained a `@registry.register` decorator, and its base class was changed to `SubComponentBase`.\n",
        "* In the outer class, `subcomponent` is now annotated with `SubComponentBase`, and a `subcomponent_class_type` member was added.\n",
        "* In any saved configuration YAML files, the key `subcomponent_args` had to be changed to `subcomponent_SubComponent_args`."
      ]
    },
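    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To illustrate the last point, settings for the chosen implementation are now passed under the renamed key. A minimal sketch, assuming a plain Python `dict` is accepted for an `_args` field in place of a `DictConfig`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Override x for the SubComponent implementation via its renamed args key.\n",
        "large_component = LargeComponent(subcomponent_SubComponent_args={\"x\": 0.5})\n",
        "\n",
        "# Each of the 4 repeats adds x = 0.5, so 3 + 4 * 0.5 == 5.\n",
        "assert large_component.apply(3) == 5"
      ]
    },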
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659462041307,
        "executionStopTime": 1659462041637,
        "originalKey": "0739269e-5c0e-4551-b06f-f4aab386ba54",
        "requestMsgId": "0739269e-5c0e-4551-b06f-f4aab386ba54",
        "showInput": false
      },
      "source": [
        "## Appendix: gotchas ⚠️\n",
        "\n",
        "* Forgetting to define `__post_init__`, or forgetting to call `run_auto_creation` in it. Without it, the subcomponents are not created.\n",
        "* Omitting the type annotation on a field. For example, writing\n",
        "  ```\n",
        "      subcomponent_class_type = \"SubComponent\"\n",
        "  ```\n",
        "  instead of\n",
        "  ```\n",
        "      subcomponent_class_type: str = \"SubComponent\"\n",
        "  ```\n"
      ]
    }
  ],
  "metadata": {
    "bento_stylesheets": {
      "bento/extensions/flow/main.css": true,
      "bento/extensions/kernel_selector/main.css": true,
      "bento/extensions/kernel_ui/main.css": true,
      "bento/extensions/new_kernel/main.css": true,
      "bento/extensions/system_usage/main.css": true,
      "bento/extensions/theme/main.css": true
    },
    "captumWidgetMessage": {},
    "dataExplorerConfig": {},
    "kernelspec": {
      "display_name": "pytorch3d",
      "language": "python",
      "metadata": {
        "cinder_runtime": false,
        "fbpkg_supported": true,
        "is_prebuilt": true,
        "kernel_name": "bento_kernel_pytorch3d",
        "nightly_builds": true
      },
      "name": "bento_kernel_pytorch3d"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3"
    },
    "last_base_url": "https://9177.od.fbinfra.net:443/",
    "last_kernel_id": "90755407-3729-46f4-ab67-ff2cb1daa5cb",
    "last_msg_id": "f61034eb-826226915ad9548ffbe495ba_6317",
    "last_server_session_id": "d6b46f14-cee7-44c1-8c51-39a38a4ea4c2",
    "outputWidgetContext": {}
  },
  "nbformat": 4,
  "nbformat_minor": 2
}
